from enum import IntEnum
from enum import auto

from devicepilot.pylog.pylogger import PyLogger

from devicepilot.hw_abstraction.measurement_unit_ctrl import MeasurementUnit

from config_enum import eef_measurement_unit_enum as meas_enum

from meas_sequences.fi_sequences import fi_single_mode_sequence
from meas_sequences.fi_sequences import FIResultData
from meas_sequences.fi_sequences import FI_RESULT_BUFFER_SIZE
from meas_sequences.fi_sequences import FI_MEASUREMENT_WINDOW

from urpc.eefmeasurementfunctions import EEFMeasurementFunctions
from urpc.eefmeasurementparameter import EEFMeasurementParameter


class PMT(IntEnum):
    """
    Identifies the PMT (and PMT operating mode) a signal was measured with.

    The values are explicit and consecutive starting at 1, which is exactly what auto() assigns for an IntEnum.
    """
    PMT1 = 1
    PMT1FI = 2
    PMT2 = 3
    PMT2FI = 4
    PMTUSLUM = 5
    PMTHTSALPHA = 6


class EEFMeasurementUnit(MeasurementUnit):
    """
    A virtual unit to load, execute and read results of measurements for the EEF board.
    """
    def __init__(self, hal_enum, endpoint_class=None, endpoint_class_sim=None):
        """
        Constructor. Reads out the config file and updates the values.

        Args:
            hal_enum: HAL config key for this measurement unit
            endpoint_class: The class to use for the EEF measurement endpoint. Used by subclasses to overwrite the default type.
                Defaults to EEFMeasurementFunctions.
            endpoint_class_sim: The class to use for the EEF measurement endpoint in simulation mode. Used by subclasses to overwrite the default type.
                Passed to the base class unchanged.
        """
        if endpoint_class is None:
            endpoint_class = EEFMeasurementFunctions

        super().__init__(hal_enum, endpoint_class, endpoint_class_sim)
        self.endpoint: EEFMeasurementFunctions = self.endpoint  # Type hint for IDEs

        # Hardware feature flags, all disabled by default.
        self.has_two_channel = False
        self.has_us_lum = False
        self.has_trf_laser = False
        self.has_interlock = False

        # has_std_alpha / has_hts_alpha are properties whose setters edit self._parameter_keys.
        # The backing fields are initialized directly so construction leaves the parameter keys untouched.
        self.__has_std_alpha = False
        self.__has_hts_alpha = False

        # Dark signal per millisecond, subtracted in counting_correction().
        self.dark_signal_per_ms_pmt1 = 0.0
        self.dark_signal_per_ms_pmt1_fi = 0.0
        self.dark_signal_per_ms_pmt2 = 0.0
        self.dark_signal_per_ms_pmt2_fi = 0.0
        self.dark_signal_per_ms_pmt_uslum = 0.0

        # Last states requested via enable_flash_lamp_power(), enable_alpha_laser() and enable_hts_alpha_laser().
        # None until the corresponding function has been called.
        self.flash_power_low = None
        self.alpha_laser_status = None
        self.hts_alpha_laser_status = None

    @property
    def has_std_alpha(self):
        """
        Whether the standard alpha laser is available.
        Setting it to a truthy value registers the AlphaLaser.Power parameter key, a falsy value removes it.
        """
        return self.__has_std_alpha

    @has_std_alpha.setter
    def has_std_alpha(self, value):
        self.__has_std_alpha = value
        if value:
            self._parameter_keys[meas_enum.AlphaLaser.Power] = EEFMeasurementParameter.AlphaLaserPower
        else:
            self._parameter_keys.pop(meas_enum.AlphaLaser.Power, None)

    @property
    def has_hts_alpha(self):
        """
        Whether the HTS alpha laser is available.
        Setting it to a truthy value registers the HTSAlphaLaser.Power parameter key, a falsy value removes it.
        """
        return self.__has_hts_alpha

    @has_hts_alpha.setter
    def has_hts_alpha(self, value):
        self.__has_hts_alpha = value
        if value:
            self._parameter_keys[meas_enum.HTSAlphaLaser.Power] = EEFMeasurementParameter.HTSAlphaLaserPower
        else:
            self._parameter_keys.pop(meas_enum.HTSAlphaLaser.Power, None)

    async def initialize_device(self):
        """
        Initializes the virtual unit, clears all measurements and uploads all firmware parameters.
        External hardware connected to the EEF is not initialized, to avoid making them unnecessarily mandatory.
        Higher level layers should do that using the provided functions like enable_pmt_hv().
        """
        await super().initialize_device()
        # Don't access external hardware here, like enable_pmt_hv(), and leave that to the higher level layers.

    async def load_fi_measurement(self, guid, measurement_time, dim_mode=False):
        """
        Build a single mode FI measurement sequence and load it for the given measurement.

        The sequence uses the low power flash frequency and charging time.

        Args:
            guid: The identifier of the measurement
            measurement_time: The measurement time, converted to a number of flashes by get_frequency_and_number_of_flashes()
            dim_mode: If True, the PMT1 analog limit from the config is passed to the sequence, otherwise -1, defaults to False
        """
        frequency_of_flashes, number_of_flashes = self.get_frequency_and_number_of_flashes(measurement_time)
        charging_time_us = self.get_config(meas_enum.FlashUnit.ChargingTimeLowPower)
        analog_limit = self.get_config(meas_enum.PMT1.AnalogLimit) if dim_mode else -1

        sequence = fi_single_mode_sequence(number_of_flashes, frequency_of_flashes, charging_time_us, analog_limit)
        await self.load_measurement(guid, sequence, result_addresses=range(FI_RESULT_BUFFER_SIZE))

    async def read_fi_measurement_results(self, guid, measurement_time, iref0):
        """
        Waits for the FI measurement to finish, reads the result data and calculates the result values.

        Args:
            guid: The identifier of the measurement
            measurement_time: The measurement time, same value as passed to load_fi_measurement()
            iref0: The reference signal correction intensity
        Returns:
            The FI measurement results as FIMeasurementResults
        """
        results = await self.read_measurement_results(guid)
        return FIMeasurementResults(self, results, measurement_time, iref0)

    def get_frequency_and_number_of_flashes(self, measurement_time, flash_power_low=True):
        """
        Get the flash frequency from the config and the number of flashes for a measurement time.

        Args:
            measurement_time: The measurement time.
                NOTE(review): the division by 1000 implies milliseconds if the configured frequency is in Hz,
                while other docstrings in this module describe the measurement time in seconds -- confirm the unit.
            flash_power_low: Use the low power flash frequency if True, the high power frequency otherwise, defaults to True
        Returns:
            Tuple (frequency_of_flashes, number_of_flashes); number_of_flashes is rounded and at least 1
        """
        if flash_power_low:
            frequency_key = meas_enum.FlashUnit.FrequencyLowPower
        else:
            frequency_key = meas_enum.FlashUnit.FrequencyHighPower
        frequency_of_flashes = self.get_config(frequency_key)
        number_of_flashes = max(round(frequency_of_flashes * measurement_time / 1000), 1)
        return frequency_of_flashes, number_of_flashes

    async def _apply_pmt_hv(self, dl_parameter, dl, hv_parameter, hv, enable_parameter, enable):
        """
        Write discriminator level, high voltage setting and high voltage enable (in that order) to the endpoint.
        The enable flag is sent as 1 or 0.
        """
        await self.endpoint.set_parameter(dl_parameter, dl)
        await self.endpoint.set_parameter(hv_parameter, hv)
        await self.endpoint.set_parameter(enable_parameter, 1 if enable else 0)

    async def enable_pmt_hv(self, enable=True):
        """
        Set the PMT settings for regular measurements and enable or disable the high voltage.
        This function sets the discriminator level and high voltage setting of PMT1 and PMT2.
        Call this function early during pre execute and avoid disabling the high voltage if not necessary.

        A delay of 1 second must be ensured when enabling the high voltage to allow the PMT to stabilize.
        When switching the high voltage setting without disabling it, a delay of 0.2 seconds is sufficient.

        Args:
            enable: True to enable the high voltage, False to disable it, defaults to True
        """
        await self._apply_pmt_hv(
            EEFMeasurementParameter.PMT1DiscriminatorLevel, self.get_config(meas_enum.PMT1.DiscriminatorLevel),
            EEFMeasurementParameter.PMT1HighVoltageSetting, self.get_config(meas_enum.PMT1.HighVoltageSetting),
            EEFMeasurementParameter.PMT1HighVoltageEnable, enable,
        )

        if self.has_two_channel:
            await self._apply_pmt_hv(
                EEFMeasurementParameter.PMT2DiscriminatorLevel, self.get_config(meas_enum.PMT2.DiscriminatorLevel),
                EEFMeasurementParameter.PMT2HighVoltageSetting, self.get_config(meas_enum.PMT2.HighVoltageSetting),
                EEFMeasurementParameter.PMT2HighVoltageEnable, enable,
            )

    async def enable_pmt_hv_fi(self, enable=True):
        """
        Set the PMT settings for FI measurements and enable or disable the high voltage.
        This function sets the discriminator level and FI high voltage setting of PMT1 and PMT2.
        Call this function early during pre execute and avoid disabling the high voltage if not necessary.

        A delay of 1 second must be ensured when enabling the high voltage to allow the PMT to stabilize.
        When switching the high voltage setting without disabling it, a delay of 0.2 seconds is sufficient.

        Args:
            enable: True to enable the high voltage, False to disable it, defaults to True
        """
        await self._apply_pmt_hv(
            EEFMeasurementParameter.PMT1DiscriminatorLevel, self.get_config(meas_enum.PMT1.DiscriminatorLevel),
            EEFMeasurementParameter.PMT1HighVoltageSetting, self.get_config(meas_enum.PMT1.HighVoltageSettingFI),
            EEFMeasurementParameter.PMT1HighVoltageEnable, enable,
        )

        if self.has_two_channel:
            await self._apply_pmt_hv(
                EEFMeasurementParameter.PMT2DiscriminatorLevel, self.get_config(meas_enum.PMT2.DiscriminatorLevel),
                EEFMeasurementParameter.PMT2HighVoltageSetting, self.get_config(meas_enum.PMT2.HighVoltageSettingFI),
                EEFMeasurementParameter.PMT2HighVoltageEnable, enable,
            )

    async def enable_pmt_hv_uslum(self, enable=True):
        """
        Set the PMT settings for USLUM measurements and enable or disable the high voltage.
        This function sets the discriminator level and high voltage setting of PMTUSLUM.
        Call this function early during pre execute and avoid disabling the high voltage if not necessary.

        A delay of 1 second must be ensured when enabling the high voltage to allow the PMT to stabilize.
        When switching the high voltage setting without disabling it, a delay of 0.2 seconds is sufficient.

        Args:
            enable: True to enable the high voltage, False to disable it, defaults to True
        """
        await self._apply_pmt_hv(
            EEFMeasurementParameter.PMTUSLUMDiscriminatorLevel, self.get_config(meas_enum.PMTUSLUM.DiscriminatorLevel),
            EEFMeasurementParameter.PMTUSLUMHighVoltageSetting, self.get_config(meas_enum.PMTUSLUM.HighVoltageSetting),
            EEFMeasurementParameter.PMTUSLUMHighVoltageEnable, enable,
        )

    async def enable_pmt_hv_hts_alpha(self, enable=True):
        """
        Set the PMT settings for HTS Alpha measurements and enable or disable the high voltage.
        This function reads the discriminator level and high voltage setting from the PMTHTSALPHA config and
        writes them to the PMTUSLUM endpoint parameters.
        NOTE(review): presumably the HTS Alpha PMT is connected to the USLUM channel -- confirm.
        Call this function early during pre execute and avoid disabling the high voltage if not necessary.

        A delay of 1 second must be ensured when enabling the high voltage to allow the PMT to stabilize.
        When switching the high voltage setting without disabling it, a delay of 0.2 seconds is sufficient.

        Args:
            enable: True to enable the high voltage, False to disable it, defaults to True
        """
        # Same write order as the other enable_pmt_hv* functions: settings first, then the enable flag,
        # so the high voltage is never switched on with a stale setting.
        await self._apply_pmt_hv(
            EEFMeasurementParameter.PMTUSLUMDiscriminatorLevel, self.get_config(meas_enum.PMTHTSALPHA.DiscriminatorLevel),
            EEFMeasurementParameter.PMTUSLUMHighVoltageSetting, self.get_config(meas_enum.PMTHTSALPHA.HighVoltageSetting),
            EEFMeasurementParameter.PMTUSLUMHighVoltageEnable, enable,
        )

    async def enable_flash_lamp_power(self, use_low_power):
        """
        Configure the flash lamp for low or high power operation.

        Writes the configured flash lamp voltage setting and the high power enable flag to the endpoint and
        remembers the selection in self.flash_power_low.

        Args:
            use_low_power: True for the low power settings, False for the high power settings
        """
        PyLogger.logger.info(f"{self.name} - enable_flash_lamp_power() - use_low_power: {use_low_power}")

        if use_low_power:
            flash_lamp_power = self.get_config(meas_enum.FlashUnit.VoltageSettingLowPower)
            high_power_enable = 0
        else:
            flash_lamp_power = self.get_config(meas_enum.FlashUnit.VoltageSettingHighPower)
            high_power_enable = 1

        await self.endpoint.set_parameter(EEFMeasurementParameter.FlashLampPower, flash_lamp_power)
        await self.endpoint.set_parameter(EEFMeasurementParameter.FlashLampHighPowerEnable, high_power_enable)
        self.flash_power_low = use_low_power

    async def enable_alpha_laser(self, enable=True):
        """
        Switch the standard alpha laser on or off and remember the requested state.

        Args:
            enable: True to enable the laser, False to disable it, defaults to True
        """
        await self.endpoint.set_parameter(EEFMeasurementParameter.AlphaLaserEnable, 1 if enable else 0)
        self.alpha_laser_status = bool(enable)

    def is_alpha_laser_power_enabled(self):
        """Return the last state set by enable_alpha_laser(), or None if it was never called."""
        return self.alpha_laser_status

    async def enable_hts_alpha_laser(self, enable=True):
        """
        Switch the HTS alpha laser on or off and remember the requested state.

        Args:
            enable: True to enable the laser, False to disable it, defaults to True
        """
        await self.endpoint.set_parameter(EEFMeasurementParameter.HTSAlphaLaserEnable, 1 if enable else 0)
        self.hts_alpha_laser_status = bool(enable)

    def is_hts_alpha_laser_power_enabled(self):
        """Return the last state set by enable_hts_alpha_laser(), or None if it was never called."""
        return self.hts_alpha_laser_status

    def _get_counting_parameters(self, pmt):
        """
        Return (pulse pair resolution, dark signal per ms) for the given PMT.

        Raises:
            ValueError: If pmt is not a supported PMT
        """
        if pmt == PMT.PMT1:
            return self.get_config(meas_enum.PMT1.PulsePairResolution), self.dark_signal_per_ms_pmt1
        if pmt == PMT.PMT1FI:
            return self.get_config(meas_enum.PMT1.PulsePairResolution), self.dark_signal_per_ms_pmt1_fi
        if pmt == PMT.PMT2:
            return self.get_config(meas_enum.PMT2.PulsePairResolution), self.dark_signal_per_ms_pmt2
        if pmt == PMT.PMT2FI:
            return self.get_config(meas_enum.PMT2.PulsePairResolution), self.dark_signal_per_ms_pmt2_fi
        if pmt == PMT.PMTUSLUM:
            return self.get_config(meas_enum.PMTUSLUM.PulsePairResolution), self.dark_signal_per_ms_pmt_uslum
        if pmt == PMT.PMTHTSALPHA:
            # Uses the USLUM dark signal, like enable_pmt_hv_hts_alpha() uses the USLUM channel parameters.
            return self.get_config(meas_enum.PMTHTSALPHA.PulsePairResolution), self.dark_signal_per_ms_pmt_uslum
        raise ValueError(f"Invalid PMT: {pmt}")

    @staticmethod
    def _clamp_to_zero(value):
        """Return 0.0 for values below 1e-20 (negative or effectively zero), otherwise return the value unchanged."""
        return 0.0 if value < 1.0e-20 else value

    def counting_correction(self, pmt: PMT, measurement_window: float, counts: int, analog_counts: int = 0):
        """
        Calculate corrected counting value.
        Only the counting part of the signal is corrected for the pulse pair resolution, and the dark signal is subtracted.

        Args:
            pmt: The PMT used for the counting measurement
            measurement_window: The total measurement window in seconds
            counts: The measured counting value to be corrected
            analog_counts: The measured analog value converted to counts, defaults to 0
        Returns:
            The corrected counting value; results below 1e-20 are clamped to 0.0
        Raises:
            ValueError: If pmt is not a supported PMT
        """
        ppr, dark_signal_per_ms = self._get_counting_parameters(pmt)

        # measurement_window is in seconds, the dark signal is given per millisecond.
        dark_signal = round(self.get_config(meas_enum.MeasurementUnit.DarkSignalFactor) * dark_signal_per_ms * measurement_window * 1000)

        # Without a counting signal there is nothing to correct, only the dark signal is subtracted.
        if counts < 1.0e-20 or counts + analog_counts < 1.0e-20:
            return self._clamp_to_zero(analog_counts - dark_signal)

        # Part of the measurement window attributed to the counting signal.
        cnt_meas_window_s = (counts / (counts + analog_counts)) * measurement_window

        if cnt_meas_window_s < 1.0e-20:
            return self._clamp_to_zero(analog_counts - dark_signal)

        denominator = 1 - (counts / cnt_meas_window_s) * ppr
        if denominator < 1.0e-20:
            PyLogger.logger.error(f"{self.name} - counting_correction() - Count rate too high for correction.. Probably saturated or wrong setting...")
            return 0.0

        return self._clamp_to_zero(counts / denominator + analog_counts - dark_signal)

    def fi_analog_correction(self, pmt: PMT, number_of_flashes, analog_counts):
        """
        Apply the FI linearity correction to an analog signal.

        The linearity table from the config is a flat list of alternating (analog counts per flash, coefficient)
        entries. The coefficient is linearly interpolated between the two breakpoints enclosing the per flash
        signal; above the last breakpoint the last coefficient is used. Below the first breakpoint (or for an
        empty table) a warning is logged and the signal is returned uncorrected.
        NOTE(review): the breakpoint search assumes the table is configured in ascending order -- confirm.

        Args:
            pmt: PMT.PMT1FI or PMT.PMT2FI
            number_of_flashes: The number of flashes the analog signal was accumulated over
            analog_counts: The total analog signal converted to counts
        Returns:
            The corrected analog signal
        Raises:
            ValueError: If pmt is not PMT.PMT1FI or PMT.PMT2FI
        """
        if pmt == PMT.PMT1FI:
            fi_linearity_array = self.get_config(meas_enum.PMT1.LinearityCorrectionFI)
        elif pmt == PMT.PMT2FI:
            fi_linearity_array = self.get_config(meas_enum.PMT2.LinearityCorrectionFI)
        else:
            raise ValueError(f"Invalid PMT: {pmt}")

        coeff_array = {fi_linearity_array[i]: fi_linearity_array[i + 1] for i in range(0, len(fi_linearity_array), 2)}

        analog_counts_per_flash = analog_counts / number_of_flashes

        # Find the last breakpoint <= signal (low_limit) and the first breakpoint > signal (high_limit).
        low_limit = None
        high_limit = None
        for key in coeff_array:
            if key > analog_counts_per_flash:
                high_limit = key
                break
            low_limit = key

        if low_limit is None:
            PyLogger.logger.warning(f"{self.name} - fi_analog_correction() - Too high or low analog signal for linearity correction: {analog_counts}")
            return analog_counts

        low_limit_coeff = coeff_array[low_limit]

        # Compare against None explicitly: a breakpoint of 0 is a valid upper limit.
        if high_limit is None:
            coefficient = low_limit_coeff
        else:
            high_limit_coeff = coeff_array[high_limit]
            coefficient = ((high_limit_coeff - low_limit_coeff)
                           * (analog_counts_per_flash - low_limit)
                           / (high_limit - low_limit)
                           + low_limit_coeff)

        return coefficient * analog_counts


class FIMeasurementResults:
    """
    Computes and stores the result values of one FI measurement.

    Attributes:
        ref_signal: Reference diode signal (low range plus scaled high range)
        pmt1_signal: PMT1 signal after IRef0 correction and normalization
        pmt2_signal: PMT2 signal after IRef0 correction and normalization

        _result_data: Raw result data of the measurement (FIResultData)
        _pmt1_analog_total: PMT1 analog signal scaled to counting values
        _pmt2_analog_total: PMT2 analog signal scaled to counting values
        _pmt1_analog_corrected: PMT1 analog signal after the FI linearity correction
        _pmt2_analog_corrected: PMT2 analog signal after the FI linearity correction
        _pmt1_total: PMT1 analog plus counting signal after the counting correction
        _pmt2_total: PMT2 analog plus counting signal after the counting correction
    """
    def __init__(self, meas_unit: EEFMeasurementUnit, results, measurement_time, iref0):
        """
        Args:
            meas_unit: The measurement unit providing the config and the correction functions
            results: The raw results read from the measurement
            measurement_time: The measurement time, same value as used to load the measurement
            iref0: The reference signal correction intensity
        """
        cfg = meas_unit.get_config

        self._result_data = raw = FIResultData(results)

        # Flash count (default low power frequency) and the resulting total measurement window.
        flash_count = meas_unit.get_frequency_and_number_of_flashes(measurement_time)[1]
        total_window = FI_MEASUREMENT_WINDOW * flash_count

        def scaled_analog(low, high, pmt_config):
            # Low range plus scaled high range, converted to counting equivalents.
            return (low + high * cfg(pmt_config.AnalogHighRangeScale)) * cfg(pmt_config.AnalogCountingEquivalent)

        self._pmt1_analog_total = scaled_analog(raw.pmt1_analog_low, raw.pmt1_analog_high, meas_enum.PMT1)
        self._pmt2_analog_total = scaled_analog(raw.pmt2_analog_low, raw.pmt2_analog_high, meas_enum.PMT2)

        # FI linearity correction of the analog part.
        self._pmt1_analog_corrected = meas_unit.fi_analog_correction(PMT.PMT1FI, flash_count, self._pmt1_analog_total)
        self._pmt2_analog_corrected = meas_unit.fi_analog_correction(PMT.PMT2FI, flash_count, self._pmt2_analog_total)

        # Corrected counting signal combined with the corrected analog signal.
        self._pmt1_total = meas_unit.counting_correction(PMT.PMT1FI, total_window, raw.pmt1_counts, self._pmt1_analog_corrected)
        self._pmt2_total = meas_unit.counting_correction(PMT.PMT2FI, total_window, raw.pmt2_counts, self._pmt2_analog_corrected)

        # Reference diode signal in analog low range units.
        self.ref_signal = raw.ref_analog_low + raw.ref_analog_high * cfg(meas_enum.PD1REF.AnalogHighRangeScale)

        # Final values: IRef0 correction against the reference signal, then the FI normalization factor.
        self.pmt1_signal = self._pmt1_total * iref0 * flash_count / self.ref_signal * cfg(meas_enum.PMT1.NormFactorFI)
        self.pmt2_signal = self._pmt2_total * iref0 * flash_count / self.ref_signal * cfg(meas_enum.PMT2.NormFactorFI)
